2022-11-15

Generative models

  • We have described how, when using squared loss, the conditional expectation or conditional probability provides the best basis for a decision rule.

  • In a binary case, the smallest true error we can achieve is determined by Bayes’ rule, which is a decision rule based on the true conditional probability:

\[ p(\mathbf{x}) = \mbox{Pr}(Y=1 \mid \mathbf{X}=\mathbf{x}) \]

  • We have described several approaches to estimating \(p(\mathbf{x})\).

Generative models

  • In all these approaches, we estimate the conditional probability directly and do not consider the distribution of the predictors.

  • In machine learning, these are referred to as discriminative approaches.

  • However, Bayes’ theorem tells us that knowing the distribution of the predictors \(\mathbf{X}\) may be useful.

  • Methods that model the joint distribution of \(Y\) and \(\mathbf{X}\) are referred to as generative models (we model how the entire data, \(\mathbf{X}\) and \(Y\), are generated).

Generative models

  • We start by describing the most general generative model, Naive Bayes, and then proceed to describe two specific cases, quadratic discriminant analysis (QDA) and linear discriminant analysis (LDA).

Naive Bayes

  • Recall that Bayes' theorem tells us that we can rewrite \(p(\mathbf{x})\) like this:

\[ \mbox{Pr}(Y=1|\mathbf{X}=\mathbf{x}) = \frac{f_{\mathbf{X}|Y=1}(\mathbf{x}) \mbox{Pr}(Y=1)} { f_{\mathbf{X}|Y=0}(\mathbf{x})\mbox{Pr}(Y=0) + f_{\mathbf{X}|Y=1}(\mathbf{x})\mbox{Pr}(Y=1) } \]

  • with \(f_{\mathbf{X}|Y=1}\) and \(f_{\mathbf{X}|Y=0}\) representing the distribution functions of the predictor \(\mathbf{X}\) for the two classes \(Y=1\) and \(Y=0\).
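
As a quick numeric check of this formula, here is a base-R sketch with made-up values for the two densities and the prevalence (all numbers hypothetical):

```r
# Toy check of Bayes' theorem with hypothetical values
f1 <- 0.30    # density of x under Y = 1
f0 <- 0.10    # density of x under Y = 0
prev <- 0.25  # Pr(Y = 1)
p <- f1 * prev / (f0 * (1 - prev) + f1 * prev)
p
#> [1] 0.5
```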

Naive Bayes

  • The formula implies that if we can estimate these conditional distributions of the predictors, we can develop a powerful decision rule.

  • However, this is a big if.

  • As we go forward, we will encounter examples in which \(\mathbf{X}\) has many dimensions and we do not have much information about the distribution.

  • In these cases, Naive Bayes will be practically impossible to implement.

Naive Bayes

  • However, in instances with a small number of predictors (not much more than 2) and many categories, generative models can be quite powerful.

  • We describe two specific examples and use our previously described case studies to illustrate them.

  • Let’s start with a very simple and uninteresting, yet illustrative, case: the example related to predicting sex from height.

Naive Bayes

library(tidyverse) 
library(caret) 
library(dslabs) 
data("heights") 
y <- heights$height 
set.seed(1995) 
test_index <- createDataPartition(y, times = 1, p = 0.5, list = FALSE) 
train_set <- heights |> slice(-test_index) 
test_set <- heights |> slice(test_index) 
  • In this case, the Naive Bayes approach is particularly appropriate because we know that the normal distribution is a good approximation for the conditional distributions of height given sex for both classes \(Y=1\) (female) and \(Y=0\) (male).

Naive Bayes

  • This implies that we can approximate the conditional distributions \(f_{X|Y=1}\) and \(f_{X|Y=0}\) by simply estimating averages and standard deviations from the data:
params <- train_set |>  
  group_by(sex) |>  
  summarize(avg = mean(height), sd = sd(height)) 
params 
#> # A tibble: 2 × 3
#>   sex      avg    sd
#>   <fct>  <dbl> <dbl>
#> 1 Female  64.8  4.14
#> 2 Male    69.2  3.57

Naive Bayes

  • The prevalence, which we will denote with \(\pi = \mbox{Pr}(Y=1)\), can be estimated from the data with:
pi <- train_set |> summarize(pi = mean(sex == "Female")) |> pull(pi) # note: masks the built-in constant pi
pi 
#> [1] 0.212

Naive Bayes

  • Now we can use our estimates of average and standard deviation to get an actual rule:
x <- test_set$height
f0 <- dnorm(x, params$avg[2], params$sd[2]) # density for males (Y = 0)
f1 <- dnorm(x, params$avg[1], params$sd[1]) # density for females (Y = 1)
p_hat_bayes <- f1*pi / (f1*pi + f0*(1 - pi))
  • Our Naive Bayes estimate \(\hat{p}(x)\) looks a lot like our logistic regression estimate:
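
The connection can be seen directly: when the class densities are normal with a shared standard deviation, the Naive Bayes posterior is exactly a logistic function of \(x\). A base-R sketch with hypothetical parameter values (the shared sd is a simplification; our estimates above differ slightly by sex):

```r
# With normal densities and a shared sd, the Naive Bayes posterior
# equals a logistic function of x (all parameter values hypothetical)
prev <- 0.2; mu1 <- 64.8; mu0 <- 69.2; s <- 3.8
x <- 66
f1 <- dnorm(x, mu1, s); f0 <- dnorm(x, mu0, s)
p_bayes <- f1 * prev / (f1 * prev + f0 * (1 - prev))
# the same value from the logistic form plogis(beta0 + beta1 * x):
beta1 <- (mu1 - mu0) / s^2
beta0 <- log(prev / (1 - prev)) - (mu1^2 - mu0^2) / (2 * s^2)
p_logit <- plogis(beta0 + beta1 * x)
all.equal(p_bayes, p_logit)
#> [1] TRUE
```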

Naive Bayes

  • In fact, we can show that the Naive Bayes approach is similar to the logistic regression prediction mathematically.

  • However, we leave the demonstration to a more advanced text, such as the Elements of Statistical Learning.

  • We can see that they are similar empirically by comparing the two resulting curves.

Controlling prevalence

  • One useful feature of the Naive Bayes approach is that it includes a parameter to account for differences in prevalence.

  • Using our sample, we estimated \(f_{X|Y=1}\), \(f_{X|Y=0}\) and \(\pi\).

  • If we use hats to denote the estimates, we can write \(\hat{p}(x)\) as:

\[ \hat{p}(x)= \frac{\hat{f}_{X|Y=1}(x) \hat{\pi}} { \hat{f}_{X|Y=0}(x)(1-\hat{\pi}) + \hat{f}_{X|Y=1}(x)\hat{\pi} } \]

Controlling prevalence

  • As we discussed earlier, our sample has a much lower prevalence of females, 0.21, than the general population.

  • So if we use the rule \(\hat{p}(x)>0.5\) to predict females, our accuracy will be affected due to the low sensitivity:

y_hat_bayes <- ifelse(p_hat_bayes > 0.5, "Female", "Male") 
sensitivity(data = factor(y_hat_bayes), reference = factor(test_set$sex)) 
#> [1] 0.213

Controlling prevalence

  • Again, this is because the algorithm gives more weight to specificity to account for the low prevalence:
specificity(data = factor(y_hat_bayes), reference = factor(test_set$sex)) 
#> [1] 0.967
  • This is due mainly to the fact that \(\hat{\pi}\) is substantially less than 0.5, so we tend to predict Male more often.

Controlling prevalence

  • It makes sense for a machine learning algorithm to do this in our sample because we do have a higher percentage of males.

  • But if we were to extrapolate this to a general population, our overall accuracy would be affected by the low sensitivity.

  • The Naive Bayes approach gives us a direct way to correct this since we can simply force \(\hat{\pi}\) to be whatever value we want it to be.

Controlling prevalence

  • So to balance specificity and sensitivity, instead of changing the cutoff in the decision rule, we could simply change \(\hat{\pi}\) to 0.5 like this:
p_hat_bayes_unbiased <- f1 * 0.5 / (f1 * 0.5 + f0 * (1 - 0.5))  
y_hat_bayes_unbiased <- ifelse(p_hat_bayes_unbiased > 0.5, "Female", "Male")
  • Note the improved sensitivity and the better balance between the two:
sensitivity(factor(y_hat_bayes_unbiased), factor(test_set$sex)) 
#> [1] 0.693
specificity(factor(y_hat_bayes_unbiased), factor(test_set$sex)) 
#> [1] 0.832

Controlling prevalence

  • The new rule also gives us a very intuitive cutoff between 66 and 67 inches, which is about halfway between the female and male average heights:
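
With \(\hat{\pi} = 0.5\), the rule predicts Female exactly when \(\hat{f}_{X|Y=1}(x) > \hat{f}_{X|Y=0}(x)\), so the cutoff is where the two estimated densities cross. A sketch using the rounded estimates printed earlier:

```r
# Find where the two estimated normal densities cross (rounded estimates
# from the params table above: Female 64.8 / 4.14, Male 69.2 / 3.57)
avg_f <- 64.8; sd_f <- 4.14
avg_m <- 69.2; sd_m <- 3.57
cutoff <- uniroot(function(x) dnorm(x, avg_f, sd_f) - dnorm(x, avg_m, sd_m),
                  interval = c(60, 75))$root
```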

Quadratic discriminant analysis

  • Quadratic Discriminant Analysis (QDA) is a version of Naive Bayes in which we assume that the conditional distributions \(f_{\mathbf{X}|Y=1}(\mathbf{x})\) and \(f_{\mathbf{X}|Y=0}(\mathbf{x})\) are multivariate normal.

  • The simple example we described in the previous section is actually QDA.

  • Let’s now look at a slightly more complicated case: the 2 or 7 example.

data("mnist_27") 
  • In this case, we have two predictors, so we assume that their joint conditional distribution within each class is bivariate normal.

Quadratic discriminant analysis

  • This implies that we need to estimate two averages, two standard deviations, and a correlation for each case \(Y=1\) and \(Y=0\).

  • Once we have these, we can approximate the distributions \(f_{X_1,X_2|Y=1}\) and \(f_{X_1, X_2|Y=0}\).

Quadratic discriminant analysis

  • We can easily estimate parameters from the data:
params <- mnist_27$train |>  
  group_by(y) |>  
  summarize(avg_1 = mean(x_1), avg_2 = mean(x_2),  
            sd_1= sd(x_1), sd_2 = sd(x_2),  
            r = cor(x_1, x_2)) 
params 
#> # A tibble: 2 × 6
#>   y     avg_1 avg_2   sd_1   sd_2     r
#>   <fct> <dbl> <dbl>  <dbl>  <dbl> <dbl>
#> 1 2     0.129 0.283 0.0702 0.0578 0.401
#> 2 7     0.234 0.288 0.0719 0.105  0.455

Quadratic discriminant analysis

  • Here we provide a visual way of showing the approach.

  • We plot the data and use contour plots to give an idea of what the two estimated normal densities look like (we show the curve representing a region that includes 95% of the points).

mnist_27$train |> mutate(y = factor(y)) |>  
  ggplot(aes(x_1, x_2, fill = y, color=y)) +  
  geom_point(show.legend = FALSE) +  
  stat_ellipse(type="norm", lwd = 1.5) 

Quadratic discriminant analysis

  • This defines the following estimate of \(f(x_1, x_2)\).

  • We can use the train function from the caret package to fit the model and obtain predictions:

library(caret) 
train_qda <- train(y ~ ., method = "qda", data = mnist_27$train) 

Quadratic discriminant analysis

  • We see that we obtain relatively good accuracy:
y_hat <- predict(train_qda, mnist_27$test) 
confusionMatrix(y_hat, mnist_27$test$y)$overall["Accuracy"] 
#> Accuracy 
#>     0.82
  • The estimated conditional probability looks relatively good, although it does not fit as well as the kernel smoothers did.

Quadratic discriminant analysis

  • One reason QDA does not work as well as the kernel methods may be that the assumption of normality does not quite hold.

  • Although the assumption appears reasonable for the 2s, it seems to be off for the 7s.

  • Notice the slight curvature in the points for the 7s:

mnist_27$train |> mutate(y = factor(y)) |>  
  ggplot(aes(x_1, x_2, fill = y, color=y)) +  
  geom_point(show.legend = FALSE) +  
  stat_ellipse(type="norm") + 
  facet_wrap(~y) 

Quadratic discriminant analysis

  • QDA can work well here, but it becomes harder to use as the number of predictors increases.

  • Here we have 2 predictors and had to compute 4 means, 4 SDs, and 2 correlations.

  • How many parameters would we have if instead of 2 predictors, we had 10?

  • The main problem comes from estimating correlations for 10 predictors.

  • With 10, we have 45 correlations for each class.

Quadratic discriminant analysis

  • In general, the formula is \(K\times p(p-1)/2\), which gets big fast.

  • Once the number of parameters approaches the size of our data, the method becomes impractical due to overfitting.
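
The counts above can be checked with quick arithmetic (K classes, p predictors):

```r
# Parameter counts for QDA with p predictors and K classes
p <- 10; K <- 2
n_means <- K * p               # 20
n_sds   <- K * p               # 20
n_cors  <- K * p * (p - 1)/2   # 45 correlations per class -> 90
n_means + n_sds + n_cors
#> [1] 130
```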

Linear discriminant analysis

  • A relatively simple solution to the problem of having too many parameters is to assume that the correlation structure is the same for all classes, which reduces the number of parameters we need to estimate.

Linear discriminant analysis

  • In this case, we would compute just one pair of standard deviations and one correlation:
params <- mnist_27$train |>  
  group_by(y) |>  
  summarize(avg_1 = mean(x_1), avg_2 = mean(x_2),  
            sd_1= sd(x_1), sd_2 = sd(x_2),  
            r = cor(x_1,x_2)) 
params <- params |> mutate(sd_1 = mean(sd_1), sd_2=mean(sd_2), r=mean(r)) 
params  
#> # A tibble: 2 × 6
#>   y     avg_1 avg_2   sd_1   sd_2     r
#>   <fct> <dbl> <dbl>  <dbl>  <dbl> <dbl>
#> 1 2     0.129 0.283 0.0710 0.0813 0.428
#> 2 7     0.234 0.288 0.0710 0.0813 0.428

Linear discriminant analysis

  • and the estimated distributions look like this:

Linear discriminant analysis

  • Now the size of the ellipses as well as the angle are the same.

  • This is because they have the same standard deviations and correlations.

  • We can fit the LDA model using caret:

train_lda <- train(y ~ ., method = "lda", data = mnist_27$train) 
y_hat <- predict(train_lda, mnist_27$test) 
confusionMatrix(y_hat, mnist_27$test$y)$overall["Accuracy"] 
#> Accuracy 
#>     0.75

Linear discriminant analysis

  • When we force this assumption, we can show mathematically that the boundary is a line, just as with logistic regression.

  • For this reason, we call the method linear discriminant analysis (LDA).

  • Similarly, for QDA, we can show that the boundary must be a quadratic function.
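
For a single predictor with class means \(\mu_1, \mu_0\) and a shared standard deviation \(\sigma\), a short derivation (a sketch; the multivariate case is analogous) shows why the boundary is linear:

\[ \log\frac{p(x)}{1-p(x)} = \log\frac{\hat{\pi}}{1-\hat{\pi}} - \frac{(x-\mu_1)^2}{2\sigma^2} + \frac{(x-\mu_0)^2}{2\sigma^2} = \log\frac{\hat{\pi}}{1-\hat{\pi}} + \frac{\mu_1-\mu_0}{\sigma^2}\,x - \frac{\mu_1^2-\mu_0^2}{2\sigma^2} \]

The quadratic terms in \(x\) cancel because \(\sigma\) is shared, so the log-odds are linear in \(x\) and setting them to zero gives a linear boundary. Without the shared-\(\sigma\) assumption, as in QDA, the \(x^2\) terms do not cancel and the boundary is quadratic.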

Linear discriminant analysis

  • In the case of LDA, the lack of flexibility does not permit us to capture the non-linearity in the true conditional probability function.

Connection to distance

  • The normal density is:

\[ f(x) = \frac{1}{\sqrt{2\pi} \sigma} \exp\left\{ - \frac{(x-\mu)^2}{2\sigma^2}\right\} \]

  • If we remove the constant \(1/(\sqrt{2\pi} \sigma)\) and then take the log, we get:

\[ - \frac{(x-\mu)^2}{2\sigma^2} \]

  • which is the negative of a distance squared scaled by the standard deviation.

  • For higher dimensions, the same is true except the scaling is more complex and involves correlations.
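
We can verify the one-dimensional identity numerically in R (the values of \(\mu\), \(\sigma\), and \(x\) are arbitrary):

```r
# The log of the normal density equals the scaled negative squared distance
# plus a constant; dnorm(mu, mu, sigma, log = TRUE) gives that constant,
# -log(sqrt(2*pi)*sigma), without referencing pi directly
mu <- 70; sigma <- 3; x <- 66
lhs <- dnorm(x, mu, sigma, log = TRUE)
rhs <- -(x - mu)^2 / (2 * sigma^2) + dnorm(mu, mu, sigma, log = TRUE)
all.equal(lhs, rhs)
#> [1] TRUE
```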

Case study: more than three classes

  • We can generate an example from the MNIST data with three categories: 1s, 2s, and 7s.
mnist <- read_mnist() # load the full MNIST data; not loaded above
index_127 <- sample(which(mnist$train$labels %in% c(1,2,7)), 2000)
y <- mnist$train$labels[index_127]  
x <- mnist$train$images[index_127,] 
index_train <- createDataPartition(y, p=0.8, list = FALSE) 
row_column <- expand.grid(row=1:28, col=1:28)  
upper_left_ind <- which(row_column$col <= 14 & row_column$row <= 14) 
lower_right_ind <- which(row_column$col > 14 & row_column$row > 14) 
x <- x > 200  
x <- cbind(rowSums(x[ ,upper_left_ind])/rowSums(x),  
           rowSums(x[ ,lower_right_ind])/rowSums(x))  
train_set <- data.frame(y = factor(y[index_train]), 
                        x_1 = x[index_train,1], x_2 = x[index_train,2]) 
test_set <- data.frame(y = factor(y[-index_train]), 
                       x_1 = x[-index_train,1], x_2 = x[-index_train,2]) 

Case study: more than three classes

  • Here is the training data:

Case study: more than three classes

  • We can use the caret package to train the QDA model:
train_qda <- train(y ~ ., method = "qda", data = train_set) 
  • Now we estimate three conditional probabilities (although they have to add to 1):
predict(train_qda, test_set, type = "prob") |> head() 
#>        1       2       7
#> 1 0.7655 0.23043 0.00405
#> 2 0.2031 0.72514 0.07175
#> 3 0.5396 0.45909 0.00132
#> 4 0.0393 0.09419 0.86655
#> 5 0.9600 0.00936 0.03063
#> 6 0.9865 0.00724 0.00623

Case study: more than three classes

  • Our predictions are one of the three classes:
predict(train_qda, test_set) |> head() 
#> [1] 1 2 1 7 1 1
#> Levels: 1 2 7
  • The confusion matrix is therefore a 3 by 3 table:
confusionMatrix(predict(train_qda, test_set), test_set$y)$table 
#>           Reference
#> Prediction   1   2   7
#>          1 111   9  11
#>          2  10  86  21
#>          7  21  28 102

Case study: more than three classes

  • The accuracy is 0.749.

  • Note that for sensitivity and specificity, we have a pair of values for each class.

  • Because sensitivity and specificity are defined for binary outcomes, we compute them one class at a time: each class in turn is treated as the positive outcome and the other two classes as the negatives.
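
For example, using the confusion matrix shown above, the one-versus-all sensitivities can be computed by hand:

```r
# One-vs-all sensitivity from the 3x3 confusion matrix shown earlier
# (rows = predicted, columns = reference)
tab <- matrix(c(111,  9,  11,
                 10, 86,  21,
                 21, 28, 102),
              nrow = 3, byrow = TRUE,
              dimnames = list(prediction = c("1","2","7"),
                              reference  = c("1","2","7")))
diag(tab) / colSums(tab)  # proportion of each true class predicted correctly
```

caret's confusionMatrix reports these per-class values as well, in its byClass component.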

Case study: more than three classes

  • To visualize what parts of the region are called 1, 2, and 7 we now need three colors:

Case study: more than three classes

  • The accuracy for LDA, 0.629, is much worse because the model is more rigid.

Case study: more than three classes

  • Let's try kNN:
train_knn <- train(y ~ ., method = "knn", data = train_set, 
                   tuneGrid = data.frame(k = seq(15, 51, 2))) 
  • The accuracy is 0.749!

Case study: more than three classes

  • The decision rule looks like this:

Case study: more than three classes

  • Note that one limitation of generative models here is the lack of fit of the normal assumption, in particular for class 1:

Case study: more than three classes

  • Generative models can be very powerful, but only when we are able to successfully approximate the joint distribution of predictors conditioned on each class.